import math
import os
import shutil
import urllib.parse
from time import time

import requests
import threadpool


class Download:
    """
    Multi-threaded streaming download helper with resumable (ranged) downloads.

    The file is split into parts of ``partSize`` bytes. Each part is fetched with
    an HTTP ``Range`` request into ``<tempDir>/<fileName>_<index>``. Once every
    part is complete, the parts are concatenated into ``<saveDir>/<fileName>``.
    Part files that already exist are resumed instead of downloaded again.
    """

    ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36"
    tempDir = "temp"  # directory holding the per-part files
    saveDir = "download"  # directory holding the merged result
    chunkSize = 2048  # bytes requested per iter_content() step

    def __init__(
            self,
            url: str,
            fileName: str = None,
            partSize: int = 1024 * 1024 * 10,
            threadNum: int = 3,
            limitTime=10000,
            timeout=30,
    ) -> None:
        """
        Initialize a download task and create the temp/save directories.

        @url: file URL
        @fileName: name of the saved file (derived from the URL by default)
        @partSize: size in bytes of one part; each part is downloaded by one
            thread (default 10 MB)
        @threadNum: number of worker threads (default 3). More is not always
            better: bandwidth is limited, and some sites cap connections per IP.
        @limitTime: maximum milliseconds allowed for one 1% progress step of a
            part (default 10000). A non-positive value or None disables the
            check. When exceeded, that part's connection is restarted (some sites
            get slower and slower on long connections).
        @timeout: requests timeout in seconds, both for connecting and between
            received bytes (default 30). None waits forever.
        """
        self.url = url
        self.partSize = partSize
        self.threadNum = threadNum
        self.limitTime = limitTime
        self.timeout = timeout
        self.fileName = fileName if fileName else self.__getFileName__()
        self.__markDir__()

    def __getFileName__(self) -> str:
        """
        Derive the file name from the URL: the last path segment, URL-decoded,
        with any query string removed.
        """
        url = urllib.parse.unquote(self.url)
        name = url.split("?")[0].split("/")[-1]
        # A URL ending in "/" has no file name. An empty name would make the
        # target path point at saveDir itself and crash the final merge.
        return name or "download.bin"

    def __markDir__(self) -> None:
        """
        Create the temp and save directories if they do not exist yet.
        """
        os.makedirs(self.tempDir, exist_ok=True)
        os.makedirs(self.saveDir, exist_ok=True)

    def __partPath__(self, index: int) -> str:
        """
        Return the path of the temporary file for part ``index``.
        """
        return os.path.join(self.tempDir, f"{self.fileName}_{index}")

    def __targetPath__(self) -> str:
        """
        Return the path of the final merged file.
        """
        return os.path.join(self.saveDir, self.fileName)

    def __getTotalSize__(self) -> int:
        """
        Return the total file size from the ``Content-Length`` header of a HEAD
        request. Return None when the status is not 200 or the header is missing.
        """
        headers = {"user-agent": self.ua}
        # requests.head() does not follow redirects by default (unlike get()).
        # Follow them so redirected download links still report their size.
        res = requests.head(
            url=self.url,
            headers=headers,
            allow_redirects=True,
            timeout=self.timeout,
        )
        size = None

        if res.status_code == 200 and "Content-Length" in res.headers:
            size = int(res.headers["Content-Length"])
            print(f"文件大小: {size / 1024 / 1024} MB")
        else:
            print(f"status_code：{res.status_code}")
        return size

    def __preDownloadPart__(self, index: int, totalSize: int) -> int:
        """
        Return how many bytes of part ``index`` (of size ``totalSize``) are still
        missing. The value is negative if the part file is larger than the part.
        """
        partPath = self.__partPath__(index)
        if os.path.isfile(partPath):
            return totalSize - os.path.getsize(partPath)
        return totalSize

    def __timedOut__(self, elapsedMs: int) -> bool:
        """
        Return True when one 1% progress step took longer than ``limitTime`` ms.
        A None or non-positive ``limitTime`` disables the check.
        """
        return (
            self.limitTime is not None
            and self.limitTime > 0
            and elapsedMs > self.limitTime
        )

    def __downloadPart__(self, index: int, rangeStart: int, rangeEnd: int) -> None:
        """
        Download one part, resuming from the bytes already on disk.

        @index: part index
        @rangeStart: first byte offset of the part (inclusive)
        @rangeEnd: last byte offset of the part (inclusive)

        If a 1% progress step exceeds ``limitTime``, the connection is dropped and
        reopened from the current offset. On a non-206 response or a network
        error the method returns, leaving the part incomplete for ``start`` to
        retry.
        """
        indexTip = f"{index + 1}/{self.partNum}"  # part position hint for logs
        totalSize = rangeEnd - rangeStart + 1  # full size of this part
        partPath = self.__partPath__(index)

        while True:
            needDownSize = self.__preDownloadPart__(index, totalSize)  # bytes still missing

            if needDownSize == 0:  # part already complete
                print(f"[{indexTip}]已完成")
                return
            if needDownSize < 0:
                # The part file is larger than the part (e.g. written with another
                # partSize); it cannot be resumed, so download the part again.
                needDownSize = totalSize
            # Append to a partially downloaded part; otherwise (re)create it.
            mode = "ba" if needDownSize != totalSize else "bw"

            headers = {
                "user-agent": self.ua,
                "Range": f"bytes={rangeEnd - needDownSize + 1}-{rangeEnd}",  # remaining range
            }

            timedOut = False
            startTime = int(time() * 1000)  # start of the current 1% step
            try:
                with requests.get(
                        url=self.url, stream=True, headers=headers, timeout=self.timeout
                ) as req:
                    if req.status_code != 206:
                        print(f"[{indexTip}][{req.status_code}]服务异常")
                        return

                    with open(partPath, mode=mode) as file:
                        currentSize = 0  # bytes received on this connection
                        progress = 0  # percent of needDownSize received
                        for content in req.iter_content(chunk_size=self.chunkSize):
                            if not content:
                                continue
                            file.write(content)
                            # Count the real chunk length; the last chunk is shorter.
                            currentSize += len(content)
                            if currentSize >= needDownSize:
                                continue
                            newProgress = int(currentSize * 100 / needDownSize)
                            if newProgress == progress:
                                continue
                            progress = newProgress
                            now = int(time() * 1000)
                            divTime = now - startTime  # time spent on this 1% step
                            if self.__timedOut__(divTime):
                                print(f"[{divTime}ms][{indexTip}][{progress}%]超时")
                                timedOut = True
                                break
                            startTime = now
                            # Draw a progress bar to make the download visible.
                            a = "*" * int(progress / 100 * 50)
                            b = "." * int((1 - progress / 100) * 50)
                            print(f"\r [{divTime}ms]\t[{indexTip}][{a}->{b}]{progress}", end="")
            except requests.RequestException as err:
                print(f"\n[{indexTip}]网络异常: {err}")
                return

            if timedOut:
                # File and response are closed here. Reconnect and resume from
                # the bytes already written.
                continue
            print(f"\n[{indexTip}][{rangeStart}-{rangeEnd}]下载完成")
            return

    def __checkParts__(self, partList: list) -> bool:
        """
        Return True when every part in ``partList`` is completely downloaded.
        Each entry is a dict with keys "index" and "totalSize".
        """
        return all(
            self.__preDownloadPart__(part["index"], part["totalSize"]) == 0
            for part in partList
        )

    def __checkFile__(self) -> bool:
        """
        Return True when the merged target file already exists.
        """
        return os.path.isfile(self.__targetPath__())

    def __mergeParts__(self):
        """
        Concatenate all parts, in index order, into the target file.
        """
        targetPath = self.__targetPath__()
        # Write to a side file and rename at the end. An interrupted merge then
        # never leaves a truncated file that __checkFile__ would take as finished.
        mergingPath = targetPath + ".merging"
        with open(mergingPath, mode="bw") as targetFile:
            for index in range(self.partNum):
                partPath = self.__partPath__(index)
                with open(partPath, mode="br") as partFile:
                    # Stream in chunks instead of reading a whole part into memory.
                    shutil.copyfileobj(partFile, targetFile)
                print(f"[{partPath}]合并成功")
        os.replace(mergingPath, targetPath)

    def __deleteParts__(self):
        """
        Delete the cached part files that exist.
        """
        for index in range(self.partNum):
            partPath = self.__partPath__(index)
            if os.path.isfile(partPath):
                os.remove(partPath)
                print(f"删除[{partPath}]成功")

    def start(self, maxRetries: int = None):
        """
        Run the download.

        Parts are downloaded in parallel on a thread pool. Rounds are repeated
        until every part is complete, then the parts are merged and deleted.
        @maxRetries: how many extra rounds to try after the first incomplete one.
            The default None retries forever.
        """
        totalSize = self.__getTotalSize__()  # total file size
        if not totalSize:
            print("文件不支持断点下载，任务退出")
            return

        # Exact integer ceiling division: number of parts covering totalSize.
        self.partNum = -(-totalSize // self.partSize)

        if self.__checkFile__():
            self.__deleteParts__()
            print("文件已下载，任务退出")
            return

        partList = []  # parts (index and size), used to check completion
        argsList = []  # positional arguments of each __downloadPart__ task
        for i in range(self.partNum):
            rangeStart = i * self.partSize
            rangeEnd = min(rangeStart + self.partSize, totalSize) - 1  # last part is shorter
            argsList.append(([i, rangeStart, rangeEnd], None))
            partList.append({"index": i, "totalSize": rangeEnd - rangeStart + 1})

        pool = threadpool.ThreadPool(self.threadNum)  # one pool reused by every round
        failedRounds = 0
        try:
            while True:
                for req in threadpool.makeRequests(self.__downloadPart__, argsList):
                    pool.putRequest(req)
                pool.wait()  # wait until every task of this round has finished

                if self.__checkParts__(partList):
                    self.__mergeParts__()
                    self.__deleteParts__()
                    print("下载结束")
                    return
                print("分块未完全下载，请重新执行程序")
                failedRounds += 1
                if maxRetries is not None and failedRounds > maxRetries:
                    return
        finally:
            # Stop the workers; without this each round/instance leaks its threads.
            pool.dismissWorkers(self.threadNum)


if __name__ == "__main__":
    # COCO 2017 image archives and annotation bundles, fetched one after another.
    cocoUrls = (
        "http://images.cocodataset.org/zips/train2017.zip",
        "http://images.cocodataset.org/zips/val2017.zip",
        "http://images.cocodataset.org/zips/test2017.zip",
        "http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
        "http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip",
        "http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip",
        "http://images.cocodataset.org/annotations/image_info_test2017.zip",
    )
    for link in cocoUrls:
        # 2 worker threads per file; limitTime=3000 is the per-part 1% progress
        # time limit in milliseconds, as described in Download.__init__.
        task = Download(url=link, threadNum=2, limitTime=3000)
        task.start()
